{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4cb58299",
   "metadata": {},
   "outputs": [],
   "source": [
    "from modeling import MambaLMHeadModel, MambaConfig\n",
    "from transformers import PretrainedConfig, AutoTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8e2ca293",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Register the custom model class for transformers auto-class loading.\n",
    "# No auto class name is passed here, so the method's default is used.\n",
    "MambaLMHeadModel.register_for_auto_class()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "49c595a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('config-1.4b.json') as fopen:\n",
    "    config = json.load(fopen)\n",
    "    config['hidden_size'] = config['d_model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "14fa359a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MambaConfig {\n",
       "  \"d_model\": 2048,\n",
       "  \"fused_add_norm\": true,\n",
       "  \"hidden_size\": 2048,\n",
       "  \"model_type\": \"mamba\",\n",
       "  \"n_layer\": 48,\n",
       "  \"pad_vocab_size_multiple\": 8,\n",
       "  \"residual_in_fp32\": true,\n",
       "  \"rms_norm\": true,\n",
       "  \"ssm_cfg\": {},\n",
       "  \"transformers_version\": \"4.35.2\",\n",
       "  \"vocab_size\": 32000\n",
       "}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config = MambaConfig(**{**config, 'vocab_size': 32000})\n",
    "config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a383738a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Instantiate the model from `config` alone; no checkpoint is loaded in this cell.\n",
    "model = MambaLMHeadModel(config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4b0815e0",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Removed shared tensor {'lm_head.weight'} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3537da1624e341968f79979aad14104b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00002-of-00002.safetensors:   0%|          | 0.00/351M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c702e18880b54dc5800f1bbd547dd58c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cde81cea5fd948a9ba3ea35d779143ee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Upload 2 LFS files:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "CommitInfo(commit_url='https://huggingface.co/huseinzol05/dummy-mamba-1.4b/commit/61e0ee5a11fd6fa9955636dba8f4410bdc65bf76', commit_message='Upload model', commit_description='', oid='61e0ee5a11fd6fa9955636dba8f4410bdc65bf76', pr_url=None, pr_revision=None, pr_num=None)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Upload the model to the Hugging Face Hub repo `huseinzol05/dummy-mamba-1.4b`.\n",
    "# NOTE(review): presumably needs network access and Hub write credentials - confirm before re-running.\n",
    "model.push_to_hub('huseinzol05/dummy-mamba-1.4b')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d92631c4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8582f4ef7024189a5d9fd35f6856de3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading config.json:   0%|          | 0.00/412 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a615cc73f6a64f4b8cd6f4b44b4f7fb8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7f18eb3a5584475a06a0c91362fc991",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ffd331fb6631491a97a527039ef78869",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)of-00002.safetensors:   0%|          | 0.00/351M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7a6260bdb6c454b8cfbde5f9997d3e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Load the model back from the same Hub repo id that the push_to_hub cell uses.\n",
    "# This rebinds `model`, replacing the instance built from `config`.\n",
    "model = MambaLMHeadModel.from_pretrained('huseinzol05/dummy-mamba-1.4b')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
